# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX bindings for Mosaic."""

# mypy: ignore-errors
from __future__ import annotations

import base64
import collections.abc
import dataclasses
import functools
import io
import os
import time
from typing import Any, Callable

import jax
from jax import core
from jax._src import config
from jax._src import sharding_impls
from jax._src.interpreters import mlir
from jax._src.lib import tpu
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import xla
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
import numpy as np

# absl is optional. When it cannot be imported, FLAGS falls back to an empty
# dict so that `FLAGS[name]` lookups raise KeyError, which the flag readers in
# this file catch and treat as "flag not set".
try:
  from absl import flags
  FLAGS = flags.FLAGS
except ImportError:
  FLAGS = {}

# Boolean config option (off by default). When it is set and layout passes are
# needed, as_tpu_kernel runs _lower_tpu_kernel on the module before serializing
# it.
_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
    name="mosaic_use_python_pipeline",
    default=False,
    help=(
        "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel"
        " is called (for Pallas, this happens at JAX lowering time), instead of"
        " later within XLA."
    ),
)

# Boolean config option (off by default). It enables the hlo -> linalg ->
# vector conversion passes in _lower_tpu_kernel and is the initial value of
# `needs_hlo_passes` in as_tpu_kernel.
_MOSAIC_ALLOW_HLO = config.define_bool_state(
    name="jax_mosaic_allow_hlo",
    default=False,
    help="Allow hlo dialects in Mosaic",
)

# The JAX primitive that carries a Mosaic kernel. Its impl goes through
# xla.apply_primitive and it is declared as producing multiple results.
tpu_custom_call_p = core.Primitive("tpu_custom_call")
tpu_custom_call_p.def_impl(
    functools.partial(xla.apply_primitive, tpu_custom_call_p))
tpu_custom_call_p.multiple_results = True


@dataclasses.dataclass(frozen=True)
class CostEstimate:
  """Immutable record of three integer cost figures for a kernel."""
  flops: int
  transcendentals: int
  bytes_accessed: int

  def to_json(self) -> bytes:
    """Returns the three fields as an ASCII-encoded JSON object."""
    entries = (
        ("flops", self.flops),
        ("transcendentals", self.transcendentals),
        ("bytes_accessed", self.bytes_accessed),
    )
    body = ", ".join(f'"{key}": {value}' for key, value in entries)
    return ("{" + body + "}").encode("ascii")


@dataclasses.dataclass(frozen=True)
class CustomCallBackendConfig:
  """Represents an unserialized backend config for custom calls.

  `to_json` turns the fields into the JSON document that is attached to the
  custom call. Fields that are `None` or false are left out of the output.
  """
  # Serialized kernel; emitted base64-encoded under "body".
  lowered_module_asm: bytes
  has_communication: bool
  collective_id: int | None
  # Emitted as "DEVICE_TYPE_<upper-cased value>".
  device_type: str | None
  cost_estimate: CostEstimate | None
  needs_hlo_passes: bool
  needs_layout_passes: bool
  # Emitted as the "size" of a single "scoped_memory_configs" entry.
  vmem_limit_bytes: int | None
  # Flag name -> value; each value must be a bool, an int or a float.
  flags: dict[str, bool | int | float] | None
  allow_input_fusion: list[bool] | None
  serialization_format: int | None

  # We omit the body while printing, because primitive params get embedded
  # in HLO metadata, and the body blows up its size.
  def __repr__(self):
    return "CustomCallBackendConfig(<omitted>)"

  def to_json(self) -> bytes:
    """Serializes the backend config into JSON.

    Returns:
      The ASCII-encoded JSON document.

    Raises:
      ValueError: If a value in `flags` is not a bool, an int or a float.
    """
    # We format the JSON ourselves, because json.dumps seems to be overly slow.
    config = io.BytesIO()
    config.write(b'{"custom_call_config": {"body": "')
    config.write(base64.b64encode(self.lowered_module_asm))
    config.write(b'"')
    if self.has_communication:
      config.write(b', "has_communication": ')
      config.write(str(self.has_communication).lower().encode("ascii"))
    if self.collective_id is not None:
      config.write(b', "collective_id": ')
      config.write(str(self.collective_id).encode("ascii"))
    if self.cost_estimate is not None:
      config.write(b', "cost_estimate": ')
      config.write(self.cost_estimate.to_json())
    if self.needs_hlo_passes:
      config.write(b', "needs_hlo_passes": ')
      config.write(str(self.needs_hlo_passes).lower().encode("ascii"))
    if self.serialization_format is not None:
      config.write(b', "serialization_format": ')
      config.write(str(self.serialization_format).lower().encode("ascii"))
    if self.needs_layout_passes:
      config.write(b', "needs_layout_passes": ')
      config.write(str(self.needs_layout_passes).lower().encode("ascii"))
    if self.allow_input_fusion is not None:
      config.write(b', "allow_input_fusion": [')
      for i, value in enumerate(self.allow_input_fusion):
        config.write(b"true" if value else b"false")
        if i + 1 != len(self.allow_input_fusion):
          config.write(b",")
      config.write(b"]")
    config.write(b"}")  # End of custom_call_config.
    # The remaining entries are siblings of "custom_call_config".
    if self.device_type is not None:
      config.write(b', "device_type": ')
      config.write(
          ('"DEVICE_TYPE_' + self.device_type.upper() + '"').encode("ascii")
      )
    if self.vmem_limit_bytes is not None:
      config.write(
          b', "scoped_memory_configs": [{"memory_space":1, "offset": 0,'
          b' "size": '
      )
      config.write(str(self.vmem_limit_bytes).encode("ascii"))
      config.write(b'}]')
    if self.flags is not None:
      config.write(b', "flag_configs": [')
      for i, (flag, value) in enumerate(self.flags.items()):
        config.write(b'{"flag_type": "')
        config.write(flag.encode("ascii"))
        # The "value" key has to be quoted for the output to be valid JSON.
        config.write(b'", "value": {')
        # bool must be tested before int, because bool is a subclass of int.
        if isinstance(value, bool):
          config.write(b'"boolean_value": ')
          config.write(b"true" if value else b"false")
        elif isinstance(value, int):
          config.write(b'"integer_value": ')
          config.write(str(value).encode("ascii"))
        elif isinstance(value, float):
          config.write(b'"double_value": ')
          config.write(str(value).encode("ascii"))
        else:
          raise ValueError("invalid flag value: " + str(value))
        config.write(b"}}")
        if i + 1 != len(self.flags):
          config.write(b",")
      config.write(b"]")
    # Prevent the compiler from sharding the custom call beyond what Mosaic does
    # based on user annotations
    config.write(b', "implicit_sharding": {"type": "MANUAL"}')
    config.write(b"}")
    return config.getvalue()


@tpu_custom_call_p.def_abstract_eval
def _tpu_custom_call_abstract_eval(
    *unused_operands, out_avals, **unused_params
):
  """Abstract evaluation rule: the result is exactly the given `out_avals`."""
  # Neither the operands nor the other primitive params affect the result.
  del unused_operands, unused_params
  return out_avals


def _aval_to_layout(aval):
  """Builds the layout attribute listing dimensions from `ndim - 1` down to 0."""
  # Counting down directly yields the same contiguous int64 array as reversing
  # an ascending range and copying it.
  descending_dims = np.arange(aval.ndim - 1, -1, -1, dtype=np.int64)
  index_type = ir.IndexType.get()
  return ir.DenseIntElementsAttr.get(descending_dims, type=index_type)


def _avals_to_layouts(avals):
  """Returns an array attribute holding one layout attribute per aval."""
  layouts = list(map(_aval_to_layout, avals))
  return ir.ArrayAttr.get(layouts)


def _tpu_custom_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,  # pylint: disable=missing-function-docstring
    config: CustomCallBackendConfig,
    kernel_name: str | None,
    kernel_regeneration_metadata: bytes | None,
    out_avals: Any,
    input_output_aliases: tuple[tuple[int, int], ...],
) -> ...:
  """Lowers `tpu_custom_call` to a stablehlo custom call op.

  Args:
    ctx: The lowering rule context.
    *in_nodes: The operands of the custom call.
    config: The backend config; it is attached to the op in JSON form.
    kernel_name: If not None, stored in the op's "kernel_name" attribute.
    kernel_regeneration_metadata: If not None, stored base64-encoded in the
      op's "kernel_regeneration_metadata" attribute.
    out_avals: The abstract values of the results.
    input_output_aliases: (operand index, result index) pairs to alias.

  Returns:
    The results of the custom call, one value per entry of `out_avals`.

  Raises:
    NotImplementedError: If the axis context would require the kernel to be
      partitioned automatically, or if the kernel has communication and the
      context is neither an SPMD axis context nor a sharding context.
  """
  i32_type = ir.IntegerType.get_signless(32)
  num_results = len(out_avals)
  multiple_results = num_results > 1
  if multiple_results:
    # Several results are returned packed into a single tuple-typed value.
    result_type = ir.TupleType.get_tuple(
        list(map(mlir.aval_to_ir_type, out_avals))
    )
  else:
    result_type = mlir.aval_to_ir_type(out_avals[0])

  partitioning_error = (
      "Mosaic kernels cannot be automatically partitioned. Please wrap the"
      " call in a shard_map or xmap."
  )
  axis_context = ctx.module_context.axis_context
  if isinstance(axis_context, sharding_impls.SPMDAxisContext):
    all_mesh_axes = frozenset(axis_context.mesh.axis_names)
    if axis_context.manual_axes != all_mesh_axes:
      raise NotImplementedError(partitioning_error)
  elif isinstance(axis_context, sharding_impls.ShardingContext):
    if axis_context.num_devices != 1:
      raise NotImplementedError(partitioning_error)
  elif config.has_communication:
    raise NotImplementedError(
        "Replica lowering for Mosaic kernels not implemented."
    )

  backend_config = ir.StringAttr.get(config.to_json())
  operand_layouts = _avals_to_layouts(ctx.avals_in)
  result_layouts = _avals_to_layouts(ctx.avals_out)
  aliases = []
  for input_idx, output_idx in input_output_aliases:
    # With a single result the aliasing refers implicitly to the only output,
    # so no tuple index is given.
    output_tuple_indices = [output_idx] if multiple_results else []
    aliases.append(
        hlo.OutputOperandAlias.get(
            output_tuple_indices=output_tuple_indices,
            operand_index=input_idx,
            operand_tuple_indices=[],
        )
    )

  call = stablehlo.CustomCallOp(
      [result_type],
      in_nodes,
      call_target_name=ir.StringAttr.get(b"tpu_custom_call"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=backend_config,
      api_version=ir.IntegerAttr.get(i32_type, 1),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=operand_layouts,
      result_layouts=result_layouts,
      output_operand_aliases=ir.ArrayAttr.get(aliases),
  )

  # The kernel name and regeneration metadata travel as attributes of the op
  # itself, so that they do not pollute the backend_config.
  if kernel_name is not None:
    call.attributes["kernel_name"] = ir.StringAttr.get(kernel_name)
  if kernel_regeneration_metadata is not None:
    encoded_metadata = base64.b64encode(kernel_regeneration_metadata)
    call.attributes["kernel_regeneration_metadata"] = ir.StringAttr.get(
        encoded_metadata
    )

  if not multiple_results:
    return call.results
  # Unpack the tuple into one value per result.
  return [
      stablehlo.get_tuple_element(call, mlir.i32_attr(idx))
      for idx in range(num_results)
  ]


# The lowering rule is registered for the "tpu" platform only.
mlir.register_lowering(tpu_custom_call_p, _tpu_custom_call_lowering,
                       platform="tpu")


def _lower_tpu_kernel(
    module: ir.Module,
    hardware_generation: int,
) -> ir.Module:
  """Verifies the module and runs the Mosaic MLIR pass pipeline on it.

  The passes run one stage at a time, in place, and the module is handed to
  `dump_mlir` after every stage. The stages are: optional hlo conversion,
  memref layout inference, simplification, optional debug assert insertion,
  vector layout inference, vector layout application and canonicalization.

  Args:
    module: The MLIR module to lower.
    hardware_generation: The TPU hardware generation to target.

  Returns:
    The same module, after the passes have been applied.

  Raises:
    ValueError: If the module fails MLIR verification, or if the
      xla_mosaic_on_device_checks flag requests an unsupported check.
  """
  try:
    module.operation.verify()
  except ir.MLIRError as e:
    raise ValueError("The compiled module fails MLIR verification") from e

  def run_stage(passes, dump_name):
    # Runs the given passes nested under builtin.module, then dumps the result.
    manager = PassManager.parse(f"builtin.module({','.join(passes)})")
    manager.run(module.operation)
    dump_mlir(module, dump_name)

  with module.context as ctx, module.operation.location as _:
    ctx.append_dialect_registry(mlir.upstream_dialects)
    ctx.load_all_available_dialects()
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_dialect(ctx)
    mhlo.register_mhlo_passes()

    dump_mlir(module, "original")

    if _MOSAIC_ALLOW_HLO.value:
      # Run hlo dialect conversion: hlo -> linalg -> vector.
      run_stage(
          [
              "hlo-legalize-to-arithmetic",
              "func.func(hlo-legalize-to-linalg)",
              "func.func(linalg-vectorization)",
          ],
          "post-hlo-conversion",
      )

    infer_memref_layout = (
        "func.func(tpu-infer-memref-layout{"
        f"hardware-generation={hardware_generation}"
        "})"
    )
    run_stage([infer_memref_layout], "post-infer-memref-layout")

    run_stage(["canonicalize", "cse"], "post-simplify")

    # A flag that is not defined is treated as "no checks requested".
    try:
      on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value
    except KeyError:
      on_device_checks = False

    if on_device_checks:
      requested = set(on_device_checks.split(","))
      if requested == {"bounds"}:  # We only support one kind of checks now.
        run_stage(
            ["func.func(debug-assert-insertion)"], "post-assert-insertion"
        )
      elif requested:
        # Report only the categories that are not supported.
        requested.discard("bounds")
        raise ValueError(
            f"Unrecognized on-device check categories: {', '.join(requested)}"
        )

    run_stage(
        ["func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})"],
        "post-infer-vector-layout",
    )

    mxu_size = 256 if hardware_generation >= 6 else 128
    apply_vector_layout = (
        "func.func(tpu-apply-vector-layout{"
        "sublane-count=8 lane-count=128"
        f" hardware-generation={hardware_generation}"
        f" mxu-contracting-size={mxu_size}"
        f" mxu-noncontracting-size={mxu_size}"
        "})"
    )
    run_stage([apply_vector_layout], "post-apply-vector-layout")

    run_stage(["canonicalize"], "pre-lower-to-llo")

    return module


def _parse_tpu_generation(device_kind: str) -> int:
  """Returns the generation number from a "TPU v<N>..." device kind string.

  Raises:
    ValueError: If `device_kind` does not start with "TPU v" followed by at
      least one decimal digit.
  """
  prefix = "TPU v"
  if device_kind.startswith(prefix):
    suffix = device_kind[len(prefix):]
    num_digits = 0
    # Take every leading digit, so that a multi-digit generation is not
    # truncated to its first digit.
    while num_digits < len(suffix) and suffix[num_digits] in "0123456789":
      num_digits += 1
    if num_digits:
      return int(suffix[:num_digits])
  raise ValueError(f"Unrecognized TPU device kind: {device_kind}.")


def as_tpu_kernel(
    module: ir.Module,
    out_type: Any,
    *,
    cost_estimate: CostEstimate | None = None,
    backend: str | xla_client.Client = "tpu",
    device_type: str | None = None,
    kernel_name: str | None = None,
    kernel_regeneration_metadata: bytes | None = None,
    vmem_limit_bytes: int | None = None,
    flags: dict[str, bool | int | float] | None = None,
    allow_input_fusion: list[bool] | None = None,
    input_output_aliases: tuple[tuple[int, int], ...] = (),
) -> Callable[..., Any]:
  """Turns an MLIR Mosaic kernel into a JAX-compatible function.

  The module is cloned, optionally lowered with the Python pass pipeline,
  serialized to MLIR bytecode and handed to `_lowered_as_tpu_kernel`.

  Args:
    module: The MLIR module holding the kernel. It is not mutated.
    out_type: The output type, or an iterable of output types, of the kernel.
    cost_estimate: Optional cost estimate to put into the backend config.
    backend: The backend whose first device determines the TPU generation.
    device_type: Optional device type. When it is set, no layout passes are
      requested or run.
    kernel_name: Optional kernel name, forwarded unchanged.
    kernel_regeneration_metadata: Optional metadata, forwarded unchanged.
    vmem_limit_bytes: Optional integer limit, forwarded unchanged.
    flags: Optional flag name to value mapping, forwarded unchanged.
    allow_input_fusion: Optional list of booleans, forwarded unchanged.
    input_output_aliases: (input index, output index) pairs, forwarded
      unchanged.

  Returns:
    The function produced by `_lowered_as_tpu_kernel`.

  Raises:
    ValueError: If the device kind of the backend is not recognized, or if
      `vmem_limit_bytes` is neither None nor an int.
  """
  some_tpu = jax.devices(backend)[0]
  device_kind = some_tpu.device_kind
  hardware_generation = _parse_tpu_generation(device_kind)
  if vmem_limit_bytes is not None and not isinstance(vmem_limit_bytes, int):
    raise ValueError(
        "vmem_limit_bytes must be an int: provided with a"
        f" {type(vmem_limit_bytes)}."
    )
  has_communication, has_custom_barrier = tpu.private_has_communication(
      module.operation
  )
  needs_hlo_passes = _MOSAIC_ALLOW_HLO.value
  needs_layout_passes = not device_type
  # We'll mutate the module, so clone it
  with module.context as ctx, module.operation.location as _:
    module = ir.Module.parse(
        module.operation.get_asm(binary=True, enable_debug_info=True)
    )
    if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
      # The passes have run here, so they are not requested again downstream.
      module = _lower_tpu_kernel(module, hardware_generation)
      needs_hlo_passes = False
      needs_layout_passes = False
    # Unregistered dialects are allowed only while the serde pass runs; the
    # previous setting is restored even if the pass fails.
    prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
    ctx.allow_unregistered_dialects = True
    try:
      pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})")
      pipeline.run(module.operation)
    finally:
      ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects
    bytecode_buffer = io.BytesIO()
    module.operation.write_bytecode(bytecode_buffer, desired_version=0)
    asm = bytecode_buffer.getvalue()

  # kernel_name and kernel_regeneration_metadata are passed on as their own
  # arguments; they are not part of the serialized module.
  return _lowered_as_tpu_kernel(
      asm,
      out_type,
      needs_hlo_passes=needs_hlo_passes,
      needs_layout_passes=needs_layout_passes,
      device_type=device_type,
      has_communication=has_communication,
      has_custom_barrier=has_custom_barrier,
      kernel_name=kernel_name,
      kernel_regeneration_metadata=kernel_regeneration_metadata,
      cost_estimate=cost_estimate,
      vmem_limit_bytes=vmem_limit_bytes,
      flags=flags,
      allow_input_fusion=allow_input_fusion,
      input_output_aliases=input_output_aliases,
  )


def _lowered_as_tpu_kernel(
    lowered_module_asm: bytes,
    out_type: Any,
    *,
    cost_estimate: CostEstimate | None = None,
    needs_hlo_passes: bool = False,
    needs_layout_passes: bool = False,
    device_type: str | None = None,
    has_communication: bool = False,
    has_custom_barrier: bool = False,
    kernel_name: str | None = None,
    kernel_regeneration_metadata: bytes | None = None,
    vmem_limit_bytes: int | None = None,
    flags: dict[str, bool | int | float] | None = None,
    allow_input_fusion: list[bool] | None = None,
    input_output_aliases: tuple[tuple[int, int], ...] = (),
    serialization_format: int | None = 1,
):
  """Wraps a serialized Mosaic kernel in a jitted JAX function.

  The returned function binds `tpu_custom_call_p` to its positional arguments.
  It accepts a static `collective_id` keyword argument, which must be given if
  and only if `has_custom_barrier` is set. If `out_type` is not iterable, the
  function returns a single result instead of a sequence of results.
  """
  # A lone (non-iterable) output type is wrapped, and unwrapped again on return.
  unpack = not isinstance(out_type, collections.abc.Iterable)
  if unpack:
    out_type = (out_type,)
  out_avals = tuple(core.ShapedArray(ty.shape, ty.dtype) for ty in out_type)

  def apply_kernel(*args, collective_id: int | None = None):
    if has_custom_barrier and collective_id is None:
      raise ValueError(
          "collective_id has to be specified when using a custom barrier"
      )
    if not has_custom_barrier and collective_id is not None:
      raise ValueError(
          "collective_id has to be unspecified or None when not using a custom"
          " barrier"
      )
    backend_config = CustomCallBackendConfig(
        lowered_module_asm=lowered_module_asm,
        has_communication=has_communication,
        collective_id=collective_id,
        device_type=device_type,
        cost_estimate=cost_estimate,
        needs_hlo_passes=needs_hlo_passes,
        needs_layout_passes=needs_layout_passes,
        vmem_limit_bytes=vmem_limit_bytes,
        flags=flags,
        allow_input_fusion=allow_input_fusion,
        serialization_format=serialization_format,
    )
    outputs = tpu_custom_call_p.bind(
        *args,
        config=backend_config,
        kernel_name=kernel_name,
        kernel_regeneration_metadata=kernel_regeneration_metadata,
        out_avals=out_avals,
        input_output_aliases=input_output_aliases,
    )
    if unpack:
      return outputs[0]
    return outputs

  return jax.jit(apply_kernel, static_argnames=["collective_id"])


def dump_mlir(module: ir.Module, name: str):
  """Writes the textual form of `module` to a file, if dumping is enabled.

  Nothing is written unless the xla_mosaic_dump_to flag exists and equals
  "sponge", and the TEST_UNDECLARED_OUTPUTS_DIR environment variable is set to
  a non-empty value. The file name combines the current time in nanoseconds
  with `name`.
  """
  try:
    dump_target = FLAGS["xla_mosaic_dump_to"].value
  except KeyError:
    return
  if dump_target != "sponge":
    return
  out_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
  if not out_dir:
    return
  dump_path = os.path.join(
      out_dir, f"{time.time_ns()}-mosaic-dump-{name}.txt"
  )
  with open(dump_path, "w") as dump_file:
    dump_file.write(str(module))
